package org.mat.framework.web.interceptor;

import org.mat.framework.core.context.RequestContextConstants;
import org.mat.framework.core.context.RequestContextUtils;
import org.mat.framework.core.context.TraceIdAdaptor;
import org.mat.framework.core.context.TraceIdGenerator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.MDC;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Spring MVC interceptor that establishes a trace id for every handled request.
 *
 * <p>In {@link #preHandle} the trace id is read from the
 * {@link RequestContextConstants#TRACE_ID_KEY_FOR_HTTP_HEADER} request header. If that value is
 * missing, blank, or contains control characters, it is obtained from the configured
 * {@link TraceIdAdaptor} (when one was supplied) and, failing that, from
 * {@link TraceIdGenerator#generate()}. The resolved id is then:
 * <ul>
 *   <li>passed to {@code RequestContextUtils.init(traceId)};</li>
 *   <li>stored in the SLF4J {@link MDC} under {@link RequestContextConstants#TRACE_ID_KEY_FOR_MDC};</li>
 *   <li>written to the response under the same header name.</li>
 * </ul>
 *
 * <p>The MDC entry is removed again in {@link #afterCompletion} so the id does not outlive the request
 * on the (typically pooled) servlet thread.
 *
 * <p>Date: 2020/3/17
 *
 * @author sunxinhe
 */
@Slf4j
public class RequestTracingInterceptor implements HandlerInterceptor {

    /** Optional trace id source consulted before the default generator; may be {@code null}. */
    private final TraceIdAdaptor traceIdAdaptor;

    /**
     * Creates the interceptor.
     *
     * @param traceIdAdaptor optional trace id source used when the upstream request does not provide
     *                       a usable id; {@code null} means only {@link TraceIdGenerator} is used
     */
    public RequestTracingInterceptor(TraceIdAdaptor traceIdAdaptor) {
        this.traceIdAdaptor = traceIdAdaptor;
    }

    /**
     * Resolves the trace id for this request and publishes it to the request context, the MDC and the
     * response header.
     *
     * @return always {@code true}; this interceptor never blocks the handler chain
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String traceId = request.getHeader(RequestContextConstants.TRACE_ID_KEY_FOR_HTTP_HEADER);
        if (isUsableTraceId(traceId)) {
            log.info("从上游请求中获取到traceId = {}", traceId);
        } else {
            // Missing, blank or malformed upstream value: fall back to the adaptor / default generator.
            // The rejected raw value is deliberately not logged (it is untrusted input).
            traceId = checkTraceId(traceId);
        }
        // Initialise the request context with the resolved trace id.
        RequestContextUtils.init(traceId);
        MDC.put(RequestContextConstants.TRACE_ID_KEY_FOR_MDC, traceId);
        // setHeader (not addHeader) so a second pass over the same response (e.g. an error dispatch)
        // replaces the value instead of emitting a duplicate trace header.
        response.setHeader(RequestContextConstants.TRACE_ID_KEY_FOR_HTTP_HEADER, traceId);

        return true;
    }

    /**
     * Guarantees a usable trace id.
     *
     * <p>If {@code traceId} is not usable, the adaptor is asked first (when present). If the result is
     * still not usable, {@link TraceIdGenerator#generate()} is used.
     *
     * @param traceId candidate trace id, may be {@code null}
     * @return {@code traceId} itself if already usable, otherwise the adaptor's or generator's value
     */
    private String checkTraceId(String traceId) {
        // Prefer the adaptor when one is configured.
        if (!isUsableTraceId(traceId) && traceIdAdaptor != null) {
            traceId = traceIdAdaptor.getTraceId();
            log.info("未检查到traceId，从适配器获取：traceId = {}", traceId);
        }
        // Still nothing usable after consulting the adaptor: use the default generator.
        if (!isUsableTraceId(traceId)) {
            traceId = TraceIdGenerator.generate();
            log.info("未检查到traceId，使用默认生成器生成：traceId = {}", traceId);
        }

        return traceId;
    }

    /**
     * Returns whether {@code traceId} can be used as-is.
     *
     * <p>A usable id is non-blank and contains no ISO control characters. Control characters
     * (CR, LF, etc.) would let an untrusted header forge log lines via the logger and the MDC.
     */
    private static boolean isUsableTraceId(String traceId) {
        return StringUtils.isNotBlank(traceId) && traceId.chars().noneMatch(Character::isISOControl);
    }

    /** No post-processing is required; overridden explicitly for HandlerInterceptor versions without defaults. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    /**
     * Clears the per-request MDC entry set in {@link #preHandle}.
     *
     * <p>Servlet threads are reused across requests. Without this, the trace id would remain attached
     * to whatever the thread logs next.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        MDC.remove(RequestContextConstants.TRACE_ID_KEY_FOR_MDC);
        // NOTE(review): RequestContextUtils.init(traceId) may also keep per-thread state. If it exposes
        // a clear/reset method, it should be called here too — confirm against RequestContextUtils.
    }
}